Added proactive heartbeat timeout failure propagation (#164) (#188) #196
base: main
Conversation
Looks pretty good! I need to do another fine-grained pass in case I missed something, plus small style stuff.
```rust
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
    slf
}

fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult<FailureNotification> {
```
nit: just use self
I get this error when I use `self`:

```
error: `PyRef<'_, FailureStream>` cannot be used as the type of `self` without the `arbitrary_self_types` feature
  note: see issue #44874 (rust-lang/rust#44874) for more information
  help: consider changing to `self`, `&self`, `&mut self`, or a type implementing `Receiver` such as `self: Box<Self>`, `self: Rc<Self>`, or `self: Arc<Self>`
```
```rust
    default_value = "1000",
    help = "How frequently to check for failures."
)]
pub failure_tick_ms: u64,
```
Any reason to use a separate `failure_tick_ms` instead of just the `quorum_tick`?
I was thinking of exposing this as a knob, but I don't have any particular reason. Should we have only a single `tick_ms` variable that is shared between `quorum_tick` and `failure_tick`?
```python
while not stop_event.is_set():
    try:
        lighthouse_client = LighthouseClient(
```
I've generally avoided hitting the Lighthouse directly from each worker and instead use a call tree through the Manager. It's probably for the best to do the same here instead of requiring all workers to hit the Lighthouse.
To confirm my understanding:
- Each ManagerClient would send a heartbeat to Lighthouse via its ManagerServer.
- If any client heartbeat times out, the ManagerServer stops forwarding heartbeats.
- Lighthouse then streams a failure notification back to the ManagerServer, which streams it to clients.
I prototyped this design at first, but in practice it duplicated the streaming logic—once in ManagerClient and again in ManagerServer—and the two paths diverged enough that the code was no longer reusable.
Instead, I adopted the simpler model where each worker talks directly to LighthouseServer, mainly because this would be easier to maintain and reason about.
Another consideration that I had is that this makes torchFT easier to extend beyond HSDP-specific scenarios (e.g., dynamically reconfiguring a pipeline after a failure in other deployment environments).
On the other hand, I agree that having the workers directly hitting the Lighthouse goes against the soft invariants maintained by the code. I would be glad to integrate the call-tree approach if you feel it adds value.
Overview
This PR improves torchFT's failure detection speed through proactive failure recovery. The Manager now listens for Lighthouse failure notifications and aborts hanging collectives immediately instead of waiting for NCCL/Gloo timeouts.
Basic demonstration
You can experiment with the proactive failure recovery mode by setting:

```sh
export TORCHFT_PROACTIVE_RECOVERY=1
```

With this enabled, the manager will listen to the Lighthouse server for heartbeat failures of other replica groups and break out of a hanging allreduce.
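For completeness, a minimal Python sketch of the two ways this PR exposes the knob (the `proactive_recovery` constructor argument is described under Implementation Details below; other `Manager` arguments are elided):

```python
import os

# Option 1: environment variable read by the Manager
# (equivalent to `export TORCHFT_PROACTIVE_RECOVERY=1` in the shell).
os.environ["TORCHFT_PROACTIVE_RECOVERY"] = "1"

# Option 2: explicit constructor flag (remaining Manager arguments elided;
# see torchft/manager.py for the full signature).
# manager = Manager(..., proactive_recovery=True)
```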
You can test this out by running `train_ddp_proactive.py`.

On shell 1 (one replica group starts initial training):
On shell 2 (a second replica group joins):
You should observe that the process with replica group id 1 will exit early, and the process with replica group id 0 will quickly resume training. If the same script is run after setting `export TORCHFT_PROACTIVE_RECOVERY=0`, you should observe that the process with replica group id 1 will hang for dozens of seconds before continuing.

In the Lighthouse logs you will observe:
```
2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - Replica train_ddp_1:a581dae2-1ebc-4f93-b882-6477832fef6b timed out (last heartbeat: Instant { tv_sec: 5200692, tv_nsec: 955240591 }), sending failure notification.
2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - Removed replica train_ddp_1:a581dae2-1ebc-4f93-b882-6477832fef6b from heartbeats and participants due to timeout.
2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - New failure detected, resetting all participants for quorum formation.
2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - Healthy replicas received failure notification for train_ddp_1:a581dae2-1ebc-4f93-b882-6477832fef6b
```

Implementation Details
The proactive failure recovery mechanism involves changes in both the Rust backend and the Python `Manager`.

Rust:

- `src/lighthouse.rs`:
  - The `Lighthouse` server now includes a `failure_channel` (a Tokio broadcast channel).
  - When `_failure_tick` detects a timed-out replica, it broadcasts a `FailureNotification` on this channel.
  - A new RPC, `subscribe_failures`, is added to `LighthouseService`. Clients can call this to receive a stream of `FailureNotification`s.
  - An `inject_failure` method has been added to the `LighthouseServer` (Python-exposed) and `Lighthouse` (Rust struct) to facilitate testing by manually triggering failure notifications.
- `src/lib.rs`:
  - A `FailureStream` class is introduced, wrapping the `tonic::Streaming<ProtoFailureNotification>`. Its `__next__` method allows Python to iterate over failure notifications. This method uses `py.allow_threads` around a blocking `runtime.block_on(fut)` call to fetch the next notification, allowing the GIL to be released.
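For illustration, here is a hedged client-side sketch of consuming the new failure stream from Python. The import path, constructor arguments, timeout keyword, server methods, and the notification's `replica_id` field are assumptions based on names used in this PR, not the exact bindings:

```python
from datetime import timedelta

# Import path is an assumption; the PR only names the classes themselves.
from torchft import LighthouseClient, LighthouseServer

# Start a Lighthouse and connect a client (constructor/method names are illustrative).
server = LighthouseServer(bind="[::]:0", min_replicas=1)
client = LighthouseClient(server.address(), connect_timeout=timedelta(seconds=5))

# Manually trigger a notification via the new inject_failure test hook.
server.inject_failure("train_ddp_1")

# subscribe_failures returns a FailureStream whose __next__ blocks (with the
# GIL released) until the Lighthouse broadcasts the next FailureNotification.
for notification in client.subscribe_failures(timeout=timedelta(seconds=5)):
    print("failure notification for:", notification.replica_id)  # field name assumed
    break

server.shutdown()
```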
Python (Manager):

- `torchft/manager.py`:
  - When `proactive_recovery` is enabled (via a constructor argument or the `TORCHFT_PROACTIVE_RECOVERY=1` environment variable), the `Manager` spawns a separate daemon process (`_failure_listener_process_main`).
  - Subprocess-based subscription: this process creates a `LighthouseClient` and calls `subscribe_failures`. It then iterates over the received failure notifications.
  - Inter-process communication (IPC): a `_ManagedPipe` is used by the listener process to send the errors it receives from the `Lighthouse` (through the stream returned by `subscribe_failures`) back to the main `Manager` process. This mimics the IPC implementation in `BabyProcessGroup`.
  - An `_error_processor_thread` in the `Manager` process continuously polls the `_error_pipe`. When an error arrives, it calls `self.report_error()` and aborts the underlying process group (`self._pg.abort()`).
  - `self.report_error()` is now also used to flag the manager as errored when a proactive failure is detected.
  - `Manager.shutdown()` is enhanced to gracefully stop the `_error_processor_thread` and the `_failure_listener_process`.
  - A `subscribe_timeout` parameter for `subscribe_failures` in `_failure_listener_process_main` allows the listener process to be interruptible for clean shutdown.
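A minimal sketch of the listener-process/IPC pattern described above, using only stdlib `multiprocessing`; the `LighthouseClient` calls, keyword names, notification formatting, and the Lighthouse address are placeholders based on names in this PR rather than the exact torchft API:

```python
import multiprocessing as mp
from datetime import timedelta


def _failure_listener(lighthouse_addr: str, error_conn, stop_event) -> None:
    """Daemon-process body: subscribe to the Lighthouse failure stream and
    forward anything received to the parent process over a pipe."""
    from torchft import LighthouseClient  # import path assumed

    while not stop_event.is_set():
        try:
            client = LighthouseClient(lighthouse_addr, connect_timeout=timedelta(seconds=5))
            # A short subscribe timeout keeps the loop interruptible for shutdown.
            for notification in client.subscribe_failures(timeout=timedelta(seconds=1)):
                error_conn.send(RuntimeError(f"Replica failure: {notification}"))
        except Exception:
            continue  # reconnect on stream errors / timeouts


if __name__ == "__main__":
    parent_conn, child_conn = mp.Pipe()
    stop_event = mp.Event()
    listener = mp.Process(
        target=_failure_listener,
        args=("http://localhost:29510", child_conn, stop_event),  # address is a placeholder
        daemon=True,
    )
    listener.start()

    # The Manager's error-processor thread polls the parent end of the pipe and,
    # on receiving an error, would call self.report_error(...) and self._pg.abort().
    if parent_conn.poll(timeout=0.1):
        print("proactive failure detected:", parent_conn.recv())

    stop_event.set()
    listener.join(timeout=2)
```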
Design Rationale

I decided to use a separate process to subscribe to the failure notifications because waiting on the failure stream is a blocking call. Because of the GIL, waiting in a Python thread would block the main thread from functioning.
As I was implementing this, I considered three approaches:

1. Handling the blocking wait directly on the Rust side in `lib.rs`.
2. Using `pyo3-asyncio` to create an async iterator from `tokio-stream`.
3. Spawning a separate listener process (the approach used here).

Approaches 1 and 2 are more elegant and should be more efficient, as they do not involve spawning a separate process. However, I am limited by my Rust language understanding and was unable to implement them.
Tests
I introduced the following tests:
- `src/lighthouse.rs`:
  - `test_subscribe_failures_delivers_notifications`: Verifies that `inject_failure` correctly sends a notification that is received by a subscriber.
  - `test_failure_tick_single_notification_and_cleanup`: Ensures `_failure_tick` correctly identifies timeouts, broadcasts notifications once, and cleans up state.
- `torchft/lighthouse_test.py`:
  - `test_subscribe_failures_notification`: Python-level test ensuring `LighthouseClient.subscribe_failures` receives notifications triggered by `LighthouseServer.inject_failure`.
  - `test_inject_failure`: Confirms that `server.inject_failure()` leads to a notification being received by `client.subscribe_failures()`.
- `torchft/manager_test.py`:
  - `test_manager_error_handler`: Tests that the `Manager` processes exceptions passed to its internal error handler.
  - `test_direct_error_pipe`: Verifies that an exception sent directly via the IPC pipe is correctly picked up by the `Manager`.
  - `test_manager_failure_e2e`: An end-to-end test where `LighthouseServer.inject_failure` triggers a notification that propagates through the listener process and IPC pipe, resulting in the `Manager` capturing the error.
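As a simplified stand-in for the IPC path that `test_direct_error_pipe` exercises, the sketch below only demonstrates the `multiprocessing.Pipe` mechanics with stdlib APIs; the real test goes through the `Manager`'s `_error_pipe`:

```python
import multiprocessing as mp


def test_error_pipe_roundtrip() -> None:
    # An exception sent over the pipe should arrive intact on the receiving
    # (Manager) side, where it would then be passed to report_error().
    parent_conn, child_conn = mp.Pipe()
    child_conn.send(RuntimeError("Replica train_ddp_1 timed out"))

    assert parent_conn.poll(timeout=1.0)
    err = parent_conn.recv()
    assert isinstance(err, RuntimeError)
    assert "train_ddp_1" in str(err)
```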
Linter

I am still getting the following error after running `lintrunner -a`, but I couldn’t debug it:

Other minor changes
Note: in order to test the code using `train_ddp.py`, I fixed an error introduced by commit 652a009 and changed the API of `DistributedSampler` to use `replica_group_id`.